{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "复习numpy\n",
    "--------------\n",
    "\n",
    "我们这里实现一个全连接的激活为ReLU的网络，它只有一个隐层，没有bias，用于回归预测一个值，loss是计算实际值和预测值的欧氏距离。\n",
    "\n",
    "我们这里完全使用numpy手动的进行前向和后向计算。\n",
    "\n",
    "numpy数组就是一个n维的数值，它并不知道任何关于深度学习、梯度下降或者计算图的东西，它只是进行数值运算。\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 33646753.85632531\n",
      "1 27453271.08056476\n",
      "2 24080728.07197475\n",
      "3 20012098.00318867\n",
      "4 15020097.40268158\n",
      "5 10091813.961896304\n",
      "6 6341133.404644929\n",
      "7 3910062.1217965647\n",
      "8 2488779.0392809846\n",
      "9 1676342.0555654871\n",
      "10 1205096.8024708242\n",
      "11 915995.4516851017\n",
      "12 726284.396605152\n",
      "13 593105.7328345232\n",
      "14 494089.0852438557\n",
      "15 417139.10831823596\n",
      "16 355578.76236067386\n",
      "17 305310.7595813564\n",
      "18 263623.60571311344\n",
      "19 228668.0907029406\n",
      "20 199128.49472756893\n",
      "21 174049.81959830096\n",
      "22 152653.4961341839\n",
      "23 134268.9237992906\n",
      "24 118418.86725722818\n",
      "25 104720.75201967722\n",
      "26 92834.60862786917\n",
      "27 82481.04671131056\n",
      "28 73439.96087784009\n",
      "29 65528.99758426403\n",
      "30 58583.28115482667\n",
      "31 52465.886487346346\n",
      "32 47063.94844026916\n",
      "33 42284.91975667222\n",
      "34 38047.84042800703\n",
      "35 34285.54548707235\n",
      "36 30937.511824526584\n",
      "37 27954.474727795743\n",
      "38 25296.420806012582\n",
      "39 22919.451085092278\n",
      "40 20789.087979882876\n",
      "41 18877.836070129975\n",
      "42 17160.75784321333\n",
      "43 15615.917956880401\n",
      "44 14223.92132894742\n",
      "45 12967.805813690054\n",
      "46 11834.965045754612\n",
      "47 10810.21305701997\n",
      "48 9882.8559380938\n",
      "49 9041.9810280073\n",
      "50 8278.968162056883\n",
      "51 7586.093751642402\n",
      "52 6956.009921659297\n",
      "53 6382.954802180364\n",
      "54 5860.849234328865\n",
      "55 5385.169889084393\n",
      "56 4951.041776152293\n",
      "57 4554.779685101734\n",
      "58 4192.744135825489\n",
      "59 3861.5620330435404\n",
      "60 3558.575568002799\n",
      "61 3281.0395107493077\n",
      "62 3026.7057109497837\n",
      "63 2793.4673438838\n",
      "64 2579.540610184513\n",
      "65 2383.112774989586\n",
      "66 2202.6219497691577\n",
      "67 2036.651349274121\n",
      "68 1883.9874194999397\n",
      "69 1743.543788585047\n",
      "70 1614.2352805477078\n",
      "71 1495.157544312757\n",
      "72 1385.3471829101436\n",
      "73 1284.1145038334066\n",
      "74 1190.6561237655994\n",
      "75 1104.4641342071607\n",
      "76 1024.8533282707592\n",
      "77 951.2702038441282\n",
      "78 883.3062246724242\n",
      "79 820.4153291438439\n",
      "80 762.2712638200151\n",
      "81 708.4410314230372\n",
      "82 658.6280068386973\n",
      "83 612.4840786589389\n",
      "84 569.7567847984675\n",
      "85 530.1405625582308\n",
      "86 493.407254385583\n",
      "87 459.35207253385965\n",
      "88 427.77590615569954\n",
      "89 398.4705458788942\n",
      "90 371.25361991002046\n",
      "91 345.9856534757829\n",
      "92 322.51343294265325\n",
      "93 300.71140570013074\n",
      "94 280.44213565786845\n",
      "95 261.60232119729324\n",
      "96 244.07747551123083\n",
      "97 227.77997334455205\n",
      "98 212.61360962023667\n",
      "99 198.50961458204173\n",
      "100 185.37523815037343\n",
      "101 173.1419420883252\n",
      "102 161.7444086307533\n",
      "103 151.1294867080377\n",
      "104 141.2418167168305\n",
      "105 132.0190084423392\n",
      "106 123.42395927159004\n",
      "107 115.40650041982232\n",
      "108 107.9331931265024\n",
      "109 100.95645429843884\n",
      "110 94.44764729886292\n",
      "111 88.37550696598488\n",
      "112 82.70965331073648\n",
      "113 77.41507577079909\n",
      "114 72.47115687898862\n",
      "115 67.85576329889035\n",
      "116 63.544450243267804\n",
      "117 59.51607914190761\n",
      "118 55.74955211290852\n",
      "119 52.231519899746964\n",
      "120 48.94166589919251\n",
      "121 45.864513670391275\n",
      "122 42.98756023211579\n",
      "123 40.297620619836636\n",
      "124 37.77944599570842\n",
      "125 35.42434915584229\n",
      "126 33.22032044610702\n",
      "127 31.15774177267508\n",
      "128 29.226589549802874\n",
      "129 27.419306720187347\n",
      "130 25.727286020788984\n",
      "131 24.14145457970656\n",
      "132 22.656144402407016\n",
      "133 21.265483588093062\n",
      "134 19.96219818899514\n",
      "135 18.740646692843463\n",
      "136 17.596115202718632\n",
      "137 16.523504720132518\n",
      "138 15.518003373288167\n",
      "139 14.574972128306609\n",
      "140 13.690940445078125\n",
      "141 12.861919935012708\n",
      "142 12.084343273855065\n",
      "143 11.35467488579186\n",
      "144 10.670435223972106\n",
      "145 10.028593982086083\n",
      "146 9.426350173250528\n",
      "147 8.86089583156576\n",
      "148 8.330302140505426\n",
      "149 7.832016693428557\n",
      "150 7.364434504327357\n",
      "151 6.925240852246151\n",
      "152 6.5130157968528755\n",
      "153 6.125703972577632\n",
      "154 5.7619569900207726\n",
      "155 5.4204849655102\n",
      "156 5.099511249077722\n",
      "157 4.798019228910501\n",
      "158 4.514613743742874\n",
      "159 4.2485534816637704\n",
      "160 3.9983117862897037\n",
      "161 3.7632063624782974\n",
      "162 3.5421547163161606\n",
      "163 3.334564409459553\n",
      "164 3.139215996763734\n",
      "165 2.9555112699209403\n",
      "166 2.7828052136929147\n",
      "167 2.6203955960283816\n",
      "168 2.467625916176105\n",
      "169 2.323922802598629\n",
      "170 2.18882086197181\n",
      "171 2.0616448974286965\n",
      "172 1.942024020381518\n",
      "173 1.829487275801244\n",
      "174 1.7236116429435369\n",
      "175 1.6239333671796419\n",
      "176 1.5301287918104347\n",
      "177 1.44188173367119\n",
      "178 1.3587846700284987\n",
      "179 1.2805611024975923\n",
      "180 1.206938433936477\n",
      "181 1.1376694262619151\n",
      "182 1.0723723867568438\n",
      "183 1.0108891587061977\n",
      "184 0.9530026660205367\n",
      "185 0.8984933344044139\n",
      "186 0.8471419816923619\n",
      "187 0.7987625653307084\n",
      "188 0.75321798877118\n",
      "189 0.710304003755077\n",
      "190 0.6698661834533903\n",
      "191 0.6317632350565083\n",
      "192 0.5958780105917416\n",
      "193 0.5620438193177038\n",
      "194 0.5301629236128438\n",
      "195 0.5001214657644013\n",
      "196 0.47181574484138933\n",
      "197 0.44512709938596284\n",
      "198 0.4199739822018371\n",
      "199 0.3962758813558982\n",
      "200 0.37391740200800294\n",
      "201 0.3528396720106114\n",
      "202 0.33296411023696443\n",
      "203 0.31423636512885483\n",
      "204 0.29656593798411324\n",
      "205 0.2799014845917933\n",
      "206 0.2641884355352626\n",
      "207 0.2493739963625293\n",
      "208 0.23539711109640377\n",
      "209 0.22221545078177463\n",
      "210 0.20978128012617805\n",
      "211 0.19805291636034555\n",
      "212 0.18698698446669054\n",
      "213 0.17654606673316697\n",
      "214 0.16670108215077523\n",
      "215 0.15740921636436464\n",
      "216 0.14863924940228926\n",
      "217 0.1403695984606317\n",
      "218 0.13256675099197637\n",
      "219 0.12519626857814992\n",
      "220 0.1182414632456814\n",
      "221 0.11168028054567249\n",
      "222 0.10548958194894455\n",
      "223 0.09964113716575237\n",
      "224 0.09412125103139254\n",
      "225 0.0889120921641555\n",
      "226 0.08399617581690047\n",
      "227 0.07935281527085905\n",
      "228 0.07496868497224299\n",
      "229 0.07083059984233567\n",
      "230 0.06692397275499679\n",
      "231 0.06323375656182109\n",
      "232 0.05974916562166152\n",
      "233 0.05646039068932693\n",
      "234 0.05335377941245477\n",
      "235 0.0504193682671884\n",
      "236 0.047649010749446566\n",
      "237 0.04503378876982777\n",
      "238 0.042561758054085366\n",
      "239 0.0402267967133619\n",
      "240 0.0380213204852353\n",
      "241 0.035938562194835974\n",
      "242 0.03397042768572205\n",
      "243 0.032110985859249375\n",
      "244 0.030354786511130064\n",
      "245 0.028695818282422892\n",
      "246 0.02712785088136848\n",
      "247 0.025646214979112375\n",
      "248 0.0242470595717837\n",
      "249 0.022924749137503447\n",
      "250 0.02167489109179123\n",
      "251 0.02049380016602744\n",
      "252 0.019378153441362202\n",
      "253 0.01832383011380395\n",
      "254 0.017326879709151505\n",
      "255 0.01638470215353354\n",
      "256 0.015494855820568717\n",
      "257 0.014653975183814854\n",
      "258 0.01385843632180708\n",
      "259 0.0131064231384097\n",
      "260 0.01239596952104221\n",
      "261 0.011724202082079683\n",
      "262 0.011088988460757135\n",
      "263 0.010488580311229106\n",
      "264 0.00992122460333594\n",
      "265 0.009384613197286546\n",
      "266 0.008877191927974426\n",
      "267 0.00839744736841792\n",
      "268 0.007944127525636445\n",
      "269 0.007515251409933891\n",
      "270 0.007109735596496277\n",
      "271 0.006726270698647757\n",
      "272 0.006363825026170815\n",
      "273 0.006020896393100611\n",
      "274 0.0056966259807518975\n",
      "275 0.005389982814497113\n",
      "276 0.005100077570626674\n",
      "277 0.00482575364733916\n",
      "278 0.004566504971664636\n",
      "279 0.004321149801363151\n",
      "280 0.004089141264481615\n",
      "281 0.0038696086589989046\n",
      "282 0.0036619227055097988\n",
      "283 0.0034655105412828425\n",
      "284 0.003279799596439323\n",
      "285 0.003104023033714323\n",
      "286 0.00293771860211258\n",
      "287 0.002780439664591278\n",
      "288 0.002631706960236262\n",
      "289 0.0024909025615654505\n",
      "290 0.002357675763046902\n",
      "291 0.0022316535968041904\n",
      "292 0.0021124728030671534\n",
      "293 0.0019996479478594396\n",
      "294 0.0018928863516532962\n",
      "295 0.001791860779027191\n",
      "296 0.0016963209468140855\n",
      "297 0.001605884633201364\n",
      "298 0.0015202936167669472\n",
      "299 0.0014392903706875785\n",
      "300 0.001362705315306394\n",
      "301 0.0012901979863674646\n",
      "302 0.0012215369910176846\n",
      "303 0.0011565547769683384\n",
      "304 0.0010950761603031277\n",
      "305 0.0010368901553818918\n",
      "306 0.0009818030120747746\n",
      "307 0.0009296578989754915\n",
      "308 0.0008803073919704202\n",
      "309 0.000833611336177691\n",
      "310 0.0007894005229484014\n",
      "311 0.0007475366128488202\n",
      "312 0.0007079122115480746\n",
      "313 0.0006704250425189281\n",
      "314 0.0006349165922570797\n",
      "315 0.0006012985412529122\n",
      "316 0.0005694745596652782\n",
      "317 0.000539356476246567\n",
      "318 0.0005108369603943474\n",
      "319 0.00048383293550072345\n",
      "320 0.00045826172141416035\n",
      "321 0.0004340599812180609\n",
      "322 0.00041114894618829586\n",
      "323 0.0003894446545374457\n",
      "324 0.00036890816081368186\n",
      "325 0.0003494499492684148\n",
      "326 0.00033103270372670894\n",
      "327 0.00031358384023738703\n",
      "328 0.0002970599852225008\n",
      "329 0.0002814116396949542\n",
      "330 0.00026660163468053457\n",
      "331 0.0002525708538685205\n",
      "332 0.00023928106056012136\n",
      "333 0.0002266961131385647\n",
      "334 0.00021477987258950498\n",
      "335 0.00020349414653085893\n",
      "336 0.00019280236280509524\n",
      "337 0.00018267483833965087\n",
      "338 0.0001730835388314521\n",
      "339 0.00016400432892495144\n",
      "340 0.00015539881520586386\n",
      "341 0.00014724766371334718\n",
      "342 0.00013952719441889499\n",
      "343 0.0001322165924471408\n",
      "344 0.00012529121141655628\n",
      "345 0.00011872843051004742\n",
      "346 0.00011251127762187287\n",
      "347 0.00010662285975896466\n",
      "348 0.00010104650535832993\n",
      "349 9.576206408603454e-05\n",
      "350 9.07571043511797e-05\n",
      "351 8.601262656497376e-05\n",
      "352 8.152004666268085e-05\n",
      "353 7.726231644616293e-05\n",
      "354 7.322757183099054e-05\n",
      "355 6.940512717538728e-05\n",
      "356 6.578373227035533e-05\n",
      "357 6.235335267722247e-05\n",
      "358 5.910145373555888e-05\n",
      "359 5.6020065629918975e-05\n",
      "360 5.3100876625948443e-05\n",
      "361 5.033501785023049e-05\n",
      "362 4.7713802164623126e-05\n",
      "363 4.5229691393377205e-05\n",
      "364 4.287560469902211e-05\n",
      "365 4.064494340177099e-05\n",
      "366 3.8531771620813476e-05\n",
      "367 3.652820494837491e-05\n",
      "368 3.462958726384474e-05\n",
      "369 3.2829946107443874e-05\n",
      "370 3.112483473537059e-05\n",
      "371 2.9508968042992977e-05\n",
      "372 2.7976982796534345e-05\n",
      "373 2.6525086534102042e-05\n",
      "374 2.5148856191966687e-05\n",
      "375 2.3845032458659837e-05\n",
      "376 2.260917722028284e-05\n",
      "377 2.1437861822316936e-05\n",
      "378 2.0327026555894084e-05\n",
      "379 1.9274091198377828e-05\n",
      "380 1.8276439386940537e-05\n",
      "381 1.733035626532829e-05\n",
      "382 1.6433412252919803e-05\n",
      "383 1.5583264739110337e-05\n",
      "384 1.4777381477040145e-05\n",
      "385 1.4013634630000862e-05\n",
      "386 1.3289388912531371e-05\n",
      "387 1.2602698536218763e-05\n",
      "388 1.1951684858889949e-05\n",
      "389 1.1334633820035646e-05\n",
      "390 1.0749697246078858e-05\n",
      "391 1.0194872470726808e-05\n",
      "392 9.668890692808508e-06\n",
      "393 9.170196839248003e-06\n",
      "394 8.697584347606335e-06\n",
      "395 8.249297563579144e-06\n",
      "396 7.824221831956458e-06\n",
      "397 7.42119949383376e-06\n",
      "398 7.03900506114464e-06\n",
      "399 6.676788848754881e-06\n",
      "400 6.333149633049269e-06\n",
      "401 6.0072837572168314e-06\n",
      "402 5.698290718735306e-06\n",
      "403 5.405286860841471e-06\n",
      "404 5.127554082399898e-06\n",
      "405 4.864072184330589e-06\n",
      "406 4.614316842979953e-06\n",
      "407 4.3773287844383165e-06\n",
      "408 4.152585108029746e-06\n",
      "409 3.939512401548747e-06\n",
      "410 3.7373532932882024e-06\n",
      "411 3.5456380292039585e-06\n",
      "412 3.3637879328135562e-06\n",
      "413 3.1913463298450933e-06\n",
      "414 3.0278280264082943e-06\n",
      "415 2.8726917702935804e-06\n",
      "416 2.7255464008753062e-06\n",
      "417 2.5859722368135234e-06\n",
      "418 2.453591336897596e-06\n",
      "419 2.3280614490235323e-06\n",
      "420 2.208949456803035e-06\n",
      "421 2.095951940451133e-06\n",
      "422 1.9887776726261266e-06\n",
      "423 1.8871110317526773e-06\n",
      "424 1.7907131472310115e-06\n",
      "425 1.6992184529946704e-06\n",
      "426 1.6124319357167558e-06\n",
      "427 1.5300962806816498e-06\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "428 1.451992367164433e-06\n",
      "429 1.3779210821161447e-06\n",
      "430 1.3076173662682819e-06\n",
      "431 1.240924839728358e-06\n",
      "432 1.1776450252452056e-06\n",
      "433 1.117618407562209e-06\n",
      "434 1.0606845366603025e-06\n",
      "435 1.0066465953464284e-06\n",
      "436 9.553852296428742e-07\n",
      "437 9.06760319941981e-07\n",
      "438 8.605999561195762e-07\n",
      "439 8.168207355867012e-07\n",
      "440 7.752650309969766e-07\n",
      "441 7.358335325904279e-07\n",
      "442 6.984182556189208e-07\n",
      "443 6.629139398230205e-07\n",
      "444 6.292344437323959e-07\n",
      "445 5.972736160506238e-07\n",
      "446 5.669366937675013e-07\n",
      "447 5.381468211347913e-07\n",
      "448 5.108301346590006e-07\n",
      "449 4.849100521774883e-07\n",
      "450 4.60313435199732e-07\n",
      "451 4.3696401610213727e-07\n",
      "452 4.1480617322114424e-07\n",
      "453 3.9377687780501087e-07\n",
      "454 3.7382564511017e-07\n",
      "455 3.5488954234910913e-07\n",
      "456 3.3691317307948826e-07\n",
      "457 3.1985133777856254e-07\n",
      "458 3.0365884476756654e-07\n",
      "459 2.8829064135645517e-07\n",
      "460 2.737087106467845e-07\n",
      "461 2.598621944334112e-07\n",
      "462 2.4672072982228823e-07\n",
      "463 2.3424726777596949e-07\n",
      "464 2.224080388842777e-07\n",
      "465 2.1117328135298994e-07\n",
      "466 2.0050538552481254e-07\n",
      "467 1.9037900041269442e-07\n",
      "468 1.807665339419012e-07\n",
      "469 1.7164154906884758e-07\n",
      "470 1.629834105854412e-07\n",
      "471 1.5476422271991237e-07\n",
      "472 1.4695883542885632e-07\n",
      "473 1.395484897167769e-07\n",
      "474 1.3251363234126593e-07\n",
      "475 1.2583563103270892e-07\n",
      "476 1.1949702403178958e-07\n",
      "477 1.1347783765005186e-07\n",
      "478 1.0776286667853162e-07\n",
      "479 1.0233730940443205e-07\n",
      "480 9.718628269648595e-08\n",
      "481 9.229746779599958e-08\n",
      "482 8.76548974330723e-08\n",
      "483 8.324596315025764e-08\n",
      "484 7.906028597295347e-08\n",
      "485 7.508590557448653e-08\n",
      "486 7.131301936609959e-08\n",
      "487 6.773074159854755e-08\n",
      "488 6.432843933232014e-08\n",
      "489 6.10979215158955e-08\n",
      "490 5.8030762011817786e-08\n",
      "491 5.511795599402395e-08\n",
      "492 5.235331551484216e-08\n",
      "493 4.972663122014098e-08\n",
      "494 4.723270328504848e-08\n",
      "495 4.4864279031499015e-08\n",
      "496 4.2615241561940016e-08\n",
      "497 4.0479793017478226e-08\n",
      "498 3.8451933334601825e-08\n",
      "499 3.6525884245304164e-08\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# N是batch size；D_in是输入大小\n",
    "# H是隐层的大小；D_out是输出大小。\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# 随机产生输入与输出\n",
    "x = np.random.randn(N, D_in)\n",
    "y = np.random.randn(N, D_out)\n",
    "\n",
    "# 随机初始化参数\n",
    "w1 = np.random.randn(D_in, H)\n",
    "w2 = np.random.randn(H, D_out)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    # 前向计算y\n",
    "    h = x.dot(w1)\n",
    "    h_relu = np.maximum(h, 0)\n",
    "    y_pred = h_relu.dot(w2)\n",
    "\n",
    "    # 计算loss\n",
    "    loss = np.square(y_pred - y).sum()\n",
    "    print(t, loss)\n",
    "\n",
    "    # 反向计算梯度 \n",
    "    grad_y_pred = 2.0 * (y_pred - y)\n",
    "    grad_w2 = h_relu.T.dot(grad_y_pred)\n",
    "    grad_h_relu = grad_y_pred.dot(w2.T)\n",
    "    grad_h = grad_h_relu.copy()\n",
    "    grad_h[h < 0] = 0\n",
    "    grad_w1 = x.T.dot(grad_h)\n",
    "\n",
    "    # 更新参数\n",
    "    w1 -= learning_rate * grad_w1\n",
    "    w2 -= learning_rate * grad_w2"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.6-env",
   "language": "python",
   "name": "py3.6-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
